{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python37764bitdlenvconda1763085c40a848e3ae8a8d3a98a401cd",
   "display_name": "Python 3.7.7 64-bit ('dl_env': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch 自动求梯度\n",
    "## 基本常识\n",
    "pytorch 会自动将一个变量的变换流程记录为 DAG，然后根据链式法则和反向传播算法求梯度\n",
    "求梯度的函数返回是一个常数，一般不允许 Tensor 求 Tensor 的梯度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "tensor([[6., 6., 6., 6.],\n        [6., 6., 6., 6.],\n        [6., 6., 6., 6.]])\n"
    }
   ],
   "source": [
    "# 求 out1 相对 x 的梯度，结果记录在 x.grad 里面\n",
    "import torch\n",
    "\n",
    "x = torch.ones(3, 4, requires_grad=True)\n",
    "out1 = (3*x*x).sum()\n",
    "out1.backward()\n",
    "print(x.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "tensor([[4., 4., 4., 4.],\n        [4., 4., 4., 4.],\n        [4., 4., 4., 4.]])\n"
    }
   ],
   "source": [
    "# 求 out1 相对 x 的梯度，结果记录在 x.grad 里面\n",
    "# 必须先清空 x 的梯度，否则会叠加\n",
    "x.grad.data.zero_()\n",
    "out2 = (2*x*x).sum()\n",
    "out2.backward()\n",
    "print(x.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 如何不记录梯度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "tensor([[8., 8., 8., 8.],\n        [8., 8., 8., 8.],\n        [8., 8., 8., 8.]])\n"
    }
   ],
   "source": [
    "# 实际中可能会有一部分计算不想计入求梯度过程中，比如预测过程，此时需要设置\n",
    "x = torch.ones(3, 4, requires_grad=True)\n",
    "y = 3*x*x + 2*x\n",
    "with torch.no_grad():\n",
    "    y += 2*x\n",
    "out1 = y.sum()\n",
    "out1.backward()\n",
    "print(x.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}